#!/usr/bin/env python3
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

import curses
import json
import os
import re
import shlex
import shutil
import signal
import subprocess
import sys
import threading
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import psutil
from huggingface_hub import HFValidationError, LocalEntryNotFoundError, snapshot_download

from models.tt_transformers.tt.model_config import parse_optimizations

# Default settings, for reference:
#   precision_cfg = {ff1_3: bfp8, ff2: bfp8, wqkv: bfp8, wo: bfp8, kv_cache: bfp8, activation: mixed}
#   fidelity_cfg = {li_ff1_3: hifi2, li_ff2: hifi2, li_qkv_decode: hifi2, sdpa_decode: hifi2na, li_o_decode: hifi2, li_qkv_prefill: hifi2, sdpa_prefill: hifi4, li_o_prefill: hifi2_fp16}
#
# NOTE: each string below overrides the default settings for one Pareto-analysis
# run. The strings are substituted verbatim into the `--optimizations '...'`
# argument of the pytest commands built in `pareto_commands`, and a string's
# position in this list becomes the numeric suffix of its command names.
# The first nine entries sweep every (ff1_3, ff2) pairing of bfp4/bfp8/bf16; the
# last three additionally override attention/output-projection settings.
# NOTE(review): components not named in a string presumably keep the defaults
# listed above -- confirm against the option parser.
dtype_mf_settings = [
    "precision_cfg = {ff1_3: bfp4, ff2: bfp4}, fidelity_cfg = {li_ff1_3: lofi, li_ff2: lofi}",
    "precision_cfg = {ff1_3: bfp4, ff2: bfp8}, fidelity_cfg = {li_ff1_3: lofi, li_ff2: hifi2}",
    "precision_cfg = {ff1_3: bfp4, ff2: bf16}, fidelity_cfg = {li_ff1_3: lofi, li_ff2: hifi4}",
    "precision_cfg = {ff1_3: bfp8, ff2: bfp4}, fidelity_cfg = {li_ff1_3: hifi2, li_ff2: lofi}",
    "precision_cfg = {ff1_3: bfp8, ff2: bfp8}, fidelity_cfg = {li_ff1_3: hifi2, li_ff2: hifi2}",
    "precision_cfg = {ff1_3: bfp8, ff2: bf16}, fidelity_cfg = {li_ff1_3: hifi2, li_ff2: hifi4}",
    "precision_cfg = {ff1_3: bf16, ff2: bfp4}, fidelity_cfg = {li_ff1_3: hifi4, li_ff2: lofi}",
    "precision_cfg = {ff1_3: bf16, ff2: bfp8}, fidelity_cfg = {li_ff1_3: hifi4, li_ff2: hifi2}",
    "precision_cfg = {ff1_3: bf16, ff2: bf16}, fidelity_cfg = {li_ff1_3: hifi4, li_ff2: hifi4}",
    "precision_cfg = {ff1_3: bfp8, ff2: bfp8, wqkv: bfp8, wo: bfp4, kv_cache: bfp8, activation: bf16}, fidelity_cfg = {li_ff1_3: hifi2, sdpa_decode: hifi2na, li_o_decode: lofi, li_qkv_prefill: hifi2, sdpa_prefill: hifi4, li_o_prefill: lofi}",
    "precision_cfg = {ff1_3: bfp8, ff2: bfp8, wqkv: bfp8, wo: bfp4, kv_cache: bfp8, activation: bfp8}, fidelity_cfg = {li_ff1_3: hifi2, sdpa_decode: hifi2na, li_o_decode: lofi, li_qkv_prefill: hifi2, sdpa_prefill: hifi4, li_o_prefill: lofi}",
    "precision_cfg = {ff1_3: bfp8, ff2: bfp8, wqkv: bfp8, wo: bfp4, kv_cache: bfp8}, fidelity_cfg = {li_ff1_3: hifi2, li_ff2: hifi2, li_o_decode: lofi, li_qkv_prefill: hifi2, sdpa_prefill: hifi4, li_o_prefill: lofi}",
]

# Maps a short command name to the pytest invocation it stands for. Every entry
# of `dtype_mf_settings` contributes two commands, in this order: a
# token-accuracy run and a batch-1 performance demo run.
# "dmf" is short for "dtype and math fidelity".
pareto_commands = dict(
    named_command
    for idx, setting in enumerate(dtype_mf_settings)
    for named_command in (
        (
            f"accuracy-dmf-{idx}",
            f"pytest models/tt_transformers/demo/simple_text_demo.py -k 'accuracy and ci-token-accuracy' --optimizations '{setting}'",
        ),
        (
            f"demo-dmf-{idx}",
            f"pytest models/tt_transformers/demo/simple_text_demo.py -k 'params0-performance and batch-1' --optimizations '{setting}'",
        ),
    )
)

# Folder scanned (relative to the current working directory) for decoder
# configuration files used by the "pareto_from_json" command group.
DECODER_CONFIG_FOLDER = Path("models/tt_transformers/tests/configurations/")

# Directory listing order is filesystem-dependent, so the files are sorted to
# keep the numeric suffix of each generated command name bound to the same file
# from one run (and one machine) to the next.
decoder_config_files = sorted(DECODER_CONFIG_FOLDER.glob("*.json"))

# Maps a short command name to the pytest invocation it stands for. Every JSON
# file contributes two commands, in this order: a token-accuracy run and a
# batch-1 demo run, both pointed at that file via --decoder_config_file.
pareto_commands_from_json = {
    name: command
    for i, config_path in enumerate(decoder_config_files)
    for name, command in {
        f"accuracy-dmf-json-{i}": f"pytest models/tt_transformers/demo/simple_text_demo.py -k 'accuracy and ci-token-accuracy' --decoder_config_file {config_path}",
        f"demo-dmf-json-{i}": f"pytest models/tt_transformers/demo/simple_text_demo.py -k 'params0 and batch-1' --decoder_config_file {config_path}",
    }.items()
}


def ensure_less_installed():
    """Make sure the `less` pager is on PATH, installing it through apt if not.

    Does nothing when `less` is already available. Otherwise runs
    `sudo apt update` followed by `sudo apt install -y less`; if either command
    fails, the error is printed and the program exits with status 1.
    """
    if shutil.which("less") is not None:
        return

    print("'less' command not found. Installing...")
    install_steps = (
        ["sudo", "apt", "update"],
        ["sudo", "apt", "install", "-y", "less"],
    )
    try:
        for step in install_steps:
            subprocess.run(step, check=True)
        print("'less' has been successfully installed.")
    except subprocess.CalledProcessError as e:
        print(f"Error installing 'less': {e}")
        sys.exit(1)


def ensure_ttsmi_installed():
    """Make sure `tt-smi` is available and a reset config file exists.

    Steps, in order:
      1. If `tt-smi` is not on PATH, pip-install it from GitHub (exit 1 on failure).
      2. Create ~/.config/tenstorrent and derive the per-host config file name
         from the HOSTNAME environment variable ("unknown" when unset).
      3. Generate the reset config: via `generate_tg_config_file()` on a TG
         host, via `tt-smi -g <file>` otherwise (exit 1 on failure).
      4. On a TG host, reset the device immediately.
    """
    if shutil.which("tt-smi") is None:
        print("'tt-smi' command not found. Installing...")
        install_cmd = ["pip", "install", "git+https://github.com/tenstorrent/tt-smi.git"]
        try:
            subprocess.run(install_cmd, check=True)
            print("'tt-smi' has been successfully installed.")
        except subprocess.CalledProcessError as e:
            print(f"Error installing 'tt-smi': {e}")
            sys.exit(1)

    # Location of the reset config file that tt-smi generates / consumes
    reset_config_dir = os.path.expanduser("~/.config/tenstorrent")
    os.makedirs(reset_config_dir, exist_ok=True)
    host = os.environ.get("HOSTNAME", "unknown")
    config_file = os.path.join(reset_config_dir, f"reset_config_{host}.json")

    is_tg = get_device() == "TG"
    try:
        if not is_tg:
            subprocess.run(["tt-smi", "-g", config_file], check=True)
        else:
            generate_tg_config_file()
        print(f"Generated reset config file: {config_file}")
    except subprocess.CalledProcessError as e:
        print(f"Error generating reset config file: {e}")
        sys.exit(1)

    if is_tg:
        # TG devices are reset once at program start
        print("Resetting device on program start...")
        reset_device_sync(config_file)


def reset_device_sync(config_file):
    """Reset the device and block until the reset command finishes.

    When the RESET_CMD environment variable is set to a non-empty value it is
    used as the reset command; it is tokenised with shell-like rules
    (`shlex.split`), so quoted arguments stay intact and repeated spaces do not
    produce empty arguments. Otherwise `tt-smi -r <config_file>` is run.

    Args:
        config_file: Path of the tt-smi reset config file; only used when
            RESET_CMD does not supply a command.

    Exits the program with status 1 if RESET_CMD cannot be parsed or if the
    reset command returns a non-zero exit code.
    """
    custom_cmd = os.environ.get("RESET_CMD")
    reset_cmd = []
    if custom_cmd:
        try:
            reset_cmd = shlex.split(custom_cmd)
        except ValueError as e:  # e.g. unbalanced quotes
            print(f"Error during device reset: invalid RESET_CMD {custom_cmd!r}: {e}")
            sys.exit(1)

    if reset_cmd:
        print(f"Resetting device using custom command: {reset_cmd}")
    else:
        # RESET_CMD unset, empty, or whitespace-only: use the tt-smi default
        reset_cmd = ["tt-smi", "-r", config_file]
        print(f"Resetting device using config file: {config_file}")

    try:
        result = subprocess.run(reset_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        print(f"Device reset successfully: {result.stdout}")
    except subprocess.CalledProcessError as e:
        print(f"Error during device reset: {e.stdout} {e.stderr}")
        sys.exit(1)


def get_device():
    """Identify the host's device type from the output of `tt-smi -ls`.

    The board count reported by `count_devices` selects the family:
    1 board -> N150 / P150, 2 boards -> N300 / P300, 8 boards -> T3K, and any
    other count -> TG. For the 1- and 2-board cases the name is the N-series
    one when the listing mentions "Wormhole" and the P-series (Blackhole) one
    otherwise.

    Returns:
        One of "N150", "P150", "N300", "P300", "T3K" or "TG".

    Raises:
        subprocess.CalledProcessError: if `tt-smi -ls` exits non-zero.
    """
    listing = subprocess.run(
        ["tt-smi", "-ls"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
    ).stdout
    board_count = count_devices(listing)

    if board_count == 8:
        return "T3K"
    if board_count not in (1, 2):
        # Everything else is treated as a TG system
        return "TG"

    is_wormhole = "Wormhole" in listing
    if board_count == 1:
        return "N150" if is_wormhole else "P150"
    return "N300" if is_wormhole else "P300"


def list_supported_devices(device):
    """Return the comma-separated device names selectable on a given host.

    Args:
        device: Host device type as returned by `get_device` (upper case).

    Returns:
        A string such as "n150, n300, t3k". Any device not listed in the table
        below (including "N150") maps to "n150".
    """
    supported_by_host = (
        ("TG", "n150, n300, t3k, tg"),
        ("T3K", "n150, n300, t3k"),
        ("N300", "n150, n300"),
        ("P300", "p150, p300"),
        ("P150", "p150"),
    )
    return next((names for host, names in supported_by_host if device == host), "n150")


def count_devices(output):
    """Count the boards listed in the output of `tt-smi -ls`.

    Only the text between the "All available boards on host" heading and the
    "Boards that can be reset" heading is inspected; every line there that
    mentions "Blackhole" or "Wormhole" counts as one device.

    Args:
        output: Captured stdout of `tt-smi -ls`.

    Returns:
        Number of matching lines.

    Raises:
        IndexError: if the "All available boards on host" heading is absent.
    """
    after_heading = output.split("All available boards on host")[1]
    board_listing = after_heading.split("Boards that can be reset")[0]

    return sum(1 for row in board_listing.split("\n") if "Blackhole" in row or "Wormhole" in row)


class OutputEntryList:
    """Ordered list of `Entry` jobs, persisted to logs/state.json.

    The list creates the `logs` directory on construction, restores previously
    saved entries, and rewrites the state file whenever entries are appended or
    popped (and whenever `save_state` is called, e.g. by an entry whose parent
    list is this object).
    """

    # Statuses describing a job that was still queued or in flight when the
    # state was written. Such a job cannot still be running after a restart,
    # so it is restored as "Cancelled".
    _INTERRUPTED_STATUSES = frozenset(
        {
            "Waiting",
            "Running",
            "Resetting",
            "Initializing device",
            "Starting",
            "Prefill",
            "Decode",
            "Terminating",
            "Exiting",
        }
    )

    def __init__(self):
        self._entries = []
        # Create logs directory
        os.makedirs("logs", exist_ok=True)
        # Load existing state
        self._load_state()

    def _load_state(self):
        """Restore entries from logs/state.json, if present and readable.

        A missing or undecodable file leaves the list empty. A file written by
        an older version (entries without a "ttft" key) is deleted and nothing
        is restored. Interrupted jobs are restored as "Cancelled".
        """
        try:
            with open("logs/state.json", "r") as f:
                state = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            return

        restored = []
        for entry_data in state:
            if "ttft" not in entry_data:
                # State file predates the TTFT attribute: discard it entirely
                # rather than keeping a partially restored list.
                try:
                    os.remove("logs/state.json")
                except OSError:
                    pass
                return

            entry = Entry(
                entry_data["command_name"],
                entry_data["model"],
                entry_data["device"],
                entry_data["command_input"],
            )
            saved_status = entry_data["status"]
            # "Loading ..." statuses are in-flight states as well (the UI
            # colours them like "Running"); without this they would be restored
            # in a state that can never progress or be re-run.
            interrupted = saved_status in self._INTERRUPTED_STATUSES or (
                isinstance(saved_status, str) and saved_status.startswith("Loading ")
            )
            entry.status = "Cancelled" if interrupted else saved_status
            entry.output = entry_data["output"]
            entry.log_id = entry_data["log_id"]
            # save_state writes these conditionally, so read them leniently
            entry.speed = entry_data.get("speed")
            entry.ttft = entry_data["ttft"]
            entry.pcc = entry_data.get("pcc")
            restored.append(entry)

        self._entries.extend(restored)
        # Attach the parent only after every entry is in the list, so that any
        # save triggered by an entry writes the complete state. Restored
        # entries then persist later changes exactly like appended ones.
        for entry in restored:
            entry.set_parent_list(self)

    def save_state(self):
        """Write the persisted attributes of every entry to logs/state.json."""
        state = []
        for entry in self._entries:
            entry_data = {
                "command_name": entry.command_name,
                "model": entry.model,
                "device": entry.device,
                "command_input": entry.command_input,
                "status": entry.status,
                "output": entry.output,
                "log_id": entry.log_id,
            }
            for optional in ("speed", "ttft", "pcc"):
                if hasattr(entry, optional):
                    entry_data[optional] = getattr(entry, optional)
            state.append(entry_data)

        with open("logs/state.json", "w") as f:
            json.dump(state, f, indent=2)

    def __len__(self):
        return len(self._entries)

    def __getitem__(self, index):
        return self._entries[index]

    def __iter__(self):
        return iter(self._entries)

    def append(self, entry):
        """Add an entry, assigning it a log ID if needed, and save the state.

        Any file in `logs` whose name starts with the entry's log prefix is
        deleted first, so a reused ID never inherits an old log.
        """
        if entry.log_id is None:
            entry.log_id = self.next_log_id()

        # Remove any existing logs with same ID
        for existing_file in os.listdir("logs"):
            if existing_file.startswith(entry.log_prefix):
                os.remove(os.path.join("logs", existing_file))

        # Set parent list reference before appending
        entry.set_parent_list(self)
        self._entries.append(entry)
        self.save_state()

    def pop(self, index):
        """Remove and return the entry at `index`, deleting its log file.

        Entries that followed the removed one are marked as changed (they move
        up one row on screen), and the state is saved.
        """
        result = self._entries.pop(index)

        try:
            os.remove(result.get_log_filename())
        except OSError:  # includes FileNotFoundError: nothing to delete
            pass

        # Mark all subsequent entries as changed
        for entry in self._entries[index:]:
            entry.mark_changed()

        self.save_state()
        return result

    def index(self, entry):
        return self._entries.index(entry)

    def get_entries(self):
        return self._entries

    def next_log_id(self):
        """Return one more than the largest log ID in use (1 for an empty list).

        Entries whose log ID is still None are ignored.
        """
        assigned_ids = (entry.log_id for entry in self._entries if entry.log_id is not None)
        return max([0, *assigned_ids]) + 1


class Entry:
    """One queued/run job: a command for a given model on a given device.

    Besides plain attribute access, entries support dictionary-style access
    (`entry["status"]`, `entry.get("speed")`) for backward compatibility.

    Every attribute assignment marks the entry as changed (needing a redraw).
    Once a parent list has been set, assigning a persisted attribute also asks
    the parent to save its state.
    """

    # Attributes whose assignment must not trigger a state save: runtime-only
    # handles plus internal bookkeeping. None of them is written by the parent
    # list's save_state, so saving on their assignment would only rewrite an
    # identical file (e.g. once per entry on every screen redraw for `changed`).
    _UNPERSISTED = frozenset(
        {"process", "log_file", "stop_event", "lock", "thread", "changed", "_parent_list"}
    )

    def __init__(self, command_name, model, device, command_input):
        self.command_name = command_name
        self.model = model
        self.device = device.upper()
        self.command_input = command_input
        self.status = "Waiting"
        self.output = ""
        self.process = None
        self.log_file = None
        self.stop_event = threading.Event()
        self.lock = threading.Lock()
        self.log_id = None  # Will be set by OutputEntryList
        self.speed = None
        self.ttft = None
        self.pcc = None
        self.thread = None
        self.changed = True  # Initialize as changed to ensure first draw
        self._parent_list = None  # Reference to parent OutputEntryList

    @property
    def log_prefix(self):
        """Generate the log file prefix based on the entry's log ID"""
        return f"{self.log_id:04d}-"

    def mark_changed(self):
        """Flag the entry as needing a redraw."""
        self.changed = True

    def mark_drawn(self):
        """Clear the redraw flag after the entry has been drawn."""
        self.changed = False

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        # Mark as changed whenever any attribute is modified
        # (except for 'changed' itself to avoid recursion)
        if name != "changed" and hasattr(self, "changed"):
            self.changed = True
        # Save state if we have a parent list and this is a persisted attribute
        parent = getattr(self, "_parent_list", None)
        if parent and name not in self._UNPERSISTED:
            parent.save_state()

    def __getitem__(self, key):
        # Support dictionary-style access for backward compatibility
        return getattr(self, key)

    def __setitem__(self, key, value):
        # Support dictionary-style assignment for backward compatibility
        setattr(self, key, value)

    def get(self, key, default=None):
        # Support dictionary-style get() for backward compatibility
        return getattr(self, key, default)

    def get_log_filename(self):
        """Generate log filename based on entry properties"""
        command_name = self._get_command_name()
        filename = f"{self.log_prefix}{self.device}-{self.model}-{command_name}.log"
        # Model names may contain "/" (e.g. org/model); keep the file inside logs/
        return os.path.join("logs", filename.replace("/", "_"))

    def _get_command_name(self):
        """Extract command name from command input

        For pytest commands this is the test file's base name without its
        extension ("pytest" if no file follows); otherwise it is the base name
        of the first token of the command.
        """
        if "pytest" in self.command_input:
            match = re.search(r"pytest\s+([\S]+)", self.command_input)
            if match:
                test_file = match.group(1)
                return os.path.basename(test_file).split(".")[0]
            return "pytest"
        return os.path.basename(shlex.split(self.command_input)[0])

    def open_log_file(self):
        """Open and return log file for writing"""
        self.log_file = open(self.get_log_filename(), "w")
        return self.log_file

    def set_parent_list(self, parent_list):
        """Attach the list that should be asked to save when this entry changes."""
        self._parent_list = parent_list


def _expand_command_alias(command_input):
    """Expand a command-group shorthand into its comma-separated command list.

    Anything that is not a known shorthand is returned unchanged.
    """
    if command_input == "pareto":
        return ",".join(pareto_commands.keys())
    if command_input == "pareto_from_json":
        return ",".join(pareto_commands_from_json.keys())

    aliases = {
        "modules": "rmsnorm,attention,attention-prefill,mlp,lm-head",
        "tests": "embedding,rmsnorm,attention,attention-prefill,mlp,lm-head,decoder,decoder-prefill,model,model-prefill",
        "table": "accuracy,demo,accuracy-acc,demo-acc",
        "vision": "vision-mlp,vision-attn,vision-block,vision-xfmr,vision-xattn,vision-xblock,vision-conv,vision-class,vision-tile-pos,vision-pos,vision-encoder,vision-text-xfmr,vision-vision-xfmr",
    }
    return aliases.get(command_input, command_input)


def _is_supported_combination(command, model, device):
    """Return True unless the (command, model, device) triple is filtered out.

    Filtered combinations:
    - 11b and 11b-b models on n150/p150
    - 70b* models on n150/n300/p150/p300
    - q2-72b on n150/n300/p150/p300
    - q2-7b on n150/t3k
    - q2-coder-32b on n150/n300/p150
    - q3-32b on n150/n300/p150
    - vision commands on models other than 11b/11b-b
    """
    return not (
        (model in ["11b", "11b-b"] and device in ["n150", "p150"])
        or (model.startswith("70b") and device in ["n150", "n300", "p150", "p300"])
        or (model == "q2-72b" and device in ["n150", "n300", "p150", "p300"])
        or (model == "q2-7b" and device in ["n150", "t3k"])
        or (model == "q2-coder-32b" and device in ["n150", "n300", "p150"])
        or (model == "q3-32b" and device in ["n150", "n300", "p150"])
        or ("vision" in command and model not in ["11b", "11b-b"])
    )


def main(stdscr):
    """Run the curses UI: three input fields on top, the job list below.

    Key handling (as implemented below):
      - Up / Down / Tab: move the selection (wrapping around).
      - Enter: on the first two fields, move to the next field; on the last
        field, expand the command/model/device inputs into job entries; on a
        job entry, open its log file in `less`.
      - Backspace: delete a character (input field) or cancel/remove the
        selected entry; `x` also cancels/removes the selected entry.
      - `X`: cancel/remove every entry; `r`: re-queue a finished, failed or
        cancelled entry; `m`: export results to markdown; `p`: reparse the
        selected entry's log.
      - Escape: start exiting (stop the worker, terminate the running job,
        mark waiting jobs as "Exiting"); a second Escape quits immediately.
      - Any other printable ASCII key is appended to the selected input field.

    Args:
        stdscr: The curses standard screen window.
    """
    curses.curs_set(0)  # Hide cursor
    curses.start_color()
    curses.use_default_colors()

    # Define color pairs using extended colors
    define_color_pairs()

    max_y, max_x = stdscr.getmaxyx()

    host_device = get_device()

    # Input fields positions (reordered)
    input_fields = [
        {"label": "Command [demo]", "value": "", "x": 0, "y": 0},
        {
            "label": "Model (1b, 3b, 8b, 11b, 70b, q2-7b, q2-72b, q2-coder-32b, q3-32b) [all]",
            "value": "",
            "x": 0,
            "y": 1,
        },
        {
            "label": f"Device ({list_supported_devices(host_device)}) [all]",
            "value": "",
            "x": 0,
            "y": 2,
        },
    ]

    output_entries = OutputEntryList()
    current_line = 0  # Index of the current line (input fields + output entries)

    screen_lock = threading.Lock()
    screen_needs_update = threading.Event()  # Signals that the screen must be redrawn
    last_drawn_state = {
        "input_fields": [],  # Start with an empty list
        "output_entries": [],
        "current_line": -1,  # Set to an invalid value to force initial draw
        "max_y": max_y,
        "max_x": max_x,
    }

    # Start the worker thread
    worker_stop_event = threading.Event()
    worker_thread = threading.Thread(
        target=worker_thread_func, args=(output_entries, worker_stop_event, screen_lock, screen_needs_update)
    )
    worker_thread.daemon = True
    worker_thread.start()

    stdscr.nodelay(True)  # Set getch() non-blocking

    # Initial draw
    draw_changes(stdscr, input_fields, output_entries, current_line, last_drawn_state)
    stdscr.refresh()

    exiting = False  # True once the first Escape has started the shutdown

    # Main loop
    while True:
        new_max_y, new_max_x = stdscr.getmaxyx()
        if new_max_y != last_drawn_state["max_y"] or new_max_x != last_drawn_state["max_x"]:
            stdscr.clear()
            last_drawn_state["max_y"], last_drawn_state["max_x"] = new_max_y, new_max_x
            last_drawn_state["current_line"] = -1  # Reset to force redraw
            screen_needs_update.set()

        if screen_needs_update.is_set():
            with screen_lock:
                # Draw everything
                draw_changes(stdscr, input_fields, output_entries, current_line, last_drawn_state)
                draw_help_bar(stdscr, current_line, len(input_fields), len(output_entries))
                stdscr.refresh()
            screen_needs_update.clear()

        c = stdscr.getch()

        # Check if we should exit after all jobs are done
        if exiting and all(
            entry["status"] in ["Exiting", "Cancelled", "Error", "Finished"] for entry in output_entries
        ):
            # Save state before exiting
            output_entries.save_state()
            return

        if c == -1:
            # No key pressed, continue to next iteration
            time.sleep(0.01)  # Short sleep to prevent high CPU usage
            continue
        elif c == 27:  # Handle escape key press
            if not exiting:
                exiting = True
                worker_stop_event.set()

                # Find the running job and set it to terminate
                running_entry = None
                for entry in output_entries:
                    with entry["lock"]:
                        if entry["process"] and entry["process"].poll() is None:
                            running_entry = entry
                            entry["stop_event"].set()
                            entry["status"] = "Terminating"
                            terminate_process_tree(entry["process"].pid)
                            break

                # Set all other jobs to "Exiting"
                for entry in output_entries:
                    with entry["lock"]:
                        if entry != running_entry and entry["status"] == "Waiting":
                            entry["status"] = "Exiting"

                # Replace the input field contents with an "Exiting" notice
                for field in input_fields:
                    field["value"] = "Exiting"

                screen_needs_update.set()
            else:
                # If escape is pressed again while exiting, force quit
                return
        elif c == curses.KEY_UP:
            current_line = (current_line - 1) % (len(input_fields) + len(output_entries))
            screen_needs_update.set()
        elif c == curses.KEY_DOWN:
            current_line = (current_line + 1) % (len(input_fields) + len(output_entries))
            screen_needs_update.set()
        elif c == curses.KEY_ENTER or c == 10 or c == 13:
            if not exiting:
                if current_line < len(input_fields):
                    # We are in input fields
                    current_field = current_line

                    # If the last field is selected, submit the command
                    if current_field == len(input_fields) - 1:
                        # Submit command; empty fields fall back to their defaults
                        command_input = input_fields[0]["value"] or "demo"
                        model_input = input_fields[1]["value"] or "1b,3b,8b,11b,70b,q2-7b,q2-72b,q2-coder-32b,q3-32b"
                        device_input = input_fields[2]["value"] or list_supported_devices(host_device)

                        command_input = _expand_command_alias(command_input)

                        # Parse models, devices, and commands
                        models = parse_list(model_input)
                        devices = parse_list(device_input)
                        commands = parse_list(command_input, allow_space=False)

                        # Generate combinations, skipping unsupported ones
                        combinations = [
                            (command, model, device)
                            for command in commands
                            for model in models
                            for device in devices
                            if _is_supported_combination(command, model, device)
                        ]

                        # Create output entries
                        for command, model, device in combinations:
                            command_name = get_command_name(command)
                            entry = Entry(command_name, model, device, command)
                            output_entries.append(entry)

                        current_line = 0
                        screen_needs_update.set()
                    else:
                        # Otherwise if not the last field, move to next field
                        total_lines = len(input_fields) + len(output_entries)
                        current_line = (current_line + 1) % total_lines
                        screen_needs_update.set()
                else:
                    # We are in the output entries
                    entry_index = current_line - len(input_fields)
                    if entry_index < len(output_entries):
                        entry = output_entries[entry_index]
                        log_path = entry.get_log_filename()
                        if os.path.exists(log_path):
                            # Save current terminal state
                            curses.def_prog_mode()
                            # Exit curses temporarily
                            curses.endwin()
                            # Run the pager with an argument list (no shell), so
                            # unusual characters in the file name cannot be
                            # misinterpreted.
                            try:
                                subprocess.run(["less", "-R", log_path], check=False)
                            except OSError:
                                pass  # pager could not be started; just resume the UI
                            # Resume curses
                            curses.reset_prog_mode()
                            stdscr.refresh()
                            screen_needs_update.set()
            else:
                # Ignore enter key when exiting
                continue
        elif c == curses.KEY_BACKSPACE or c == 127 or (c == ord("x") and current_line >= len(input_fields)):
            if current_line < len(input_fields):
                current_field = current_line
                # Remove last character from current field
                if len(input_fields[current_field]["value"]) > 0:
                    input_fields[current_field]["value"] = input_fields[current_field]["value"][:-1]
                    # Redraw so the deleted character disappears immediately
                    screen_needs_update.set()
            else:
                # We are in the output entries
                entry_index = current_line - len(input_fields)
                if entry_index < len(output_entries):
                    entry = output_entries[entry_index]
                    if cancel_entry(entry):
                        output_entries.pop(entry_index)
                        total_lines = len(input_fields) + len(output_entries)
                        if current_line >= total_lines:
                            current_line = total_lines - 1
                    screen_needs_update.set()
        elif c == ord("X") and current_line >= len(input_fields):  # Shift-X to clear all entries
            to_remove = []
            for entry in output_entries:
                if cancel_entry(entry):
                    to_remove.append(entry)
            for entry in to_remove:
                entry_index = output_entries.index(entry)
                output_entries.pop(entry_index)
            screen_needs_update.set()
            total_lines = len(input_fields) + len(output_entries)
            if current_line >= total_lines:
                current_line = total_lines - 1
        elif c == 9:  # Tab key
            total_lines = len(input_fields) + len(output_entries)
            current_line = (current_line + 1) % total_lines
            screen_needs_update.set()
        elif c == ord("r") and current_line >= len(input_fields):
            entry_index = current_line - len(input_fields)
            if entry_index < len(output_entries):
                entry = output_entries[entry_index]
                with entry["lock"]:
                    if entry["status"] in ["Finished", "Error", "Cancelled"]:
                        # Reset the entry to "Waiting" status
                        entry["status"] = "Waiting"
                        entry["output"] = ""
                        entry["speed"] = None
                        entry["ttft"] = None
                        entry["pcc"] = None
                        entry["process"] = None
                        entry["log_file"] = None
                        entry["stop_event"].clear()
                        screen_needs_update.set()
        elif c == ord("m") and current_line >= len(input_fields):
            # Export results to markdown
            export_results_to_markdown(output_entries, stdscr)
            last_drawn_state["current_line"] = -1  # Reset to force redraw
            screen_needs_update.set()
        elif c == ord("p") and current_line >= len(input_fields):
            # Reparse the selected entry's log file
            entry_index = current_line - len(input_fields)
            if entry_index < len(output_entries):
                entry = output_entries[entry_index]
                reparse_log_file(entry, screen_needs_update)
        else:
            # Only printable ASCII is typed into a field. getch() also reports
            # special keys (arrow/function keys, KEY_RESIZE, ...) and control
            # characters as integers; those must not end up in the text.
            if current_line < len(input_fields) and not exiting and 32 <= c <= 126:
                current_field = current_line
                input_fields[current_field]["value"] += chr(c)
            screen_needs_update.set()


def define_color_pairs():
    """Initialise the curses color pairs and publish them as module globals.

    Pairs 1-13 are registered with `curses.init_pair`; the resulting
    attributes are stored in the `COLOR_PAIR_*` globals used by the drawing
    functions. Must be called after curses color support has been started.
    """
    global COLOR_PAIR_SELECTED, COLOR_PAIR_LABEL, COLOR_PAIR_VALUE, COLOR_PAIR_HEADER
    global COLOR_PAIR_WAITING, COLOR_PAIR_RUNNING, COLOR_PAIR_FINISHED, COLOR_PAIR_ERROR
    global COLOR_PAIR_SPEED
    global COLOR_PAIR_PCC_GREEN, COLOR_PAIR_PCC_YELLOW, COLOR_PAIR_PCC_RED, COLOR_PAIR_PCC_ACCURACY
    global COLOR_PAIR_HELP_BAR

    # Extended color codes (assuming 256-color support), muted pastel tones
    pastel_cyan = 152
    pastel_green = 108
    pastel_yellow = 229
    pastel_red = 174
    light_gray = 250
    dark_gray = 244
    bright_white = 15
    black = 16
    default_bg = -1  # terminal's default background

    # (pair number, foreground, background)
    pair_table = (
        (1, black, light_gray),  # Selected field/background
        (2, pastel_cyan, default_bg),  # Labels
        (3, bright_white, default_bg),  # Input values
        (4, bright_white, default_bg),  # Header text
        (5, light_gray, default_bg),  # 'Waiting' status
        (6, pastel_yellow, default_bg),  # 'Running' status
        (7, pastel_green, default_bg),  # 'Finished' status
        (8, pastel_red, default_bg),  # 'Error' status
        (9, pastel_green, default_bg),  # PCC > 0.99
        (10, pastel_yellow, default_bg),  # PCC 0.98-0.99
        (11, pastel_red, default_bg),  # PCC < 0.98
        (12, dark_gray, default_bg),  # Accuracy percentages
        (13, pastel_cyan, default_bg),  # Help bar
    )
    for pair_number, foreground, background in pair_table:
        curses.init_pair(pair_number, foreground, background)

    # Store the color pair attributes for use in the rest of the program
    COLOR_PAIR_SELECTED = curses.color_pair(1)
    COLOR_PAIR_LABEL = curses.color_pair(2)
    COLOR_PAIR_VALUE = curses.color_pair(3)
    COLOR_PAIR_HEADER = curses.color_pair(4) | curses.A_BOLD
    COLOR_PAIR_WAITING = curses.color_pair(5)
    COLOR_PAIR_RUNNING = curses.color_pair(6)
    COLOR_PAIR_FINISHED = curses.color_pair(7)
    COLOR_PAIR_ERROR = curses.color_pair(8)
    COLOR_PAIR_SPEED = curses.color_pair(6)  # Same color as RUNNING
    COLOR_PAIR_PCC_GREEN = curses.color_pair(9)
    COLOR_PAIR_PCC_YELLOW = curses.color_pair(10)
    COLOR_PAIR_PCC_RED = curses.color_pair(11)
    COLOR_PAIR_PCC_ACCURACY = curses.color_pair(12)
    COLOR_PAIR_HELP_BAR = curses.color_pair(13)


def draw_changes(stdscr, input_fields, output_entries, current_line, last_drawn_state):
    """Redraw the parts of the screen that differ from what was last drawn.

    Input fields are redrawn when their content changed, when they were never
    drawn, or when the selection moved. The divider and header are always
    redrawn. An output entry is redrawn when it is flagged as changed or the
    selection moved; rows left over from removed entries are cleared. Entries
    are only drawn above the last three screen rows.

    Args:
        stdscr: Curses window to draw on.
        input_fields: List of input field dicts.
        output_entries: Iterable, sized collection of entries.
        current_line: Index of the selected line (fields first, then entries).
        last_drawn_state: Dict caching what was drawn; updated in place.
    """
    max_y, max_x = stdscr.getmaxyx()
    selection_moved = current_line != last_drawn_state["current_line"]
    drawn_fields = last_drawn_state["input_fields"]

    # Update input fields
    for idx, field in enumerate(input_fields):
        previously_drawn = idx < len(drawn_fields)
        if previously_drawn and field == drawn_fields[idx] and not selection_moved:
            continue
        draw_input_field(stdscr, field, idx == current_line, max_x)
        snapshot = field.copy()
        if previously_drawn:
            drawn_fields[idx] = snapshot
        else:
            drawn_fields.append(snapshot)

    # Divider line directly below the input fields
    divider_row = len(input_fields)
    stdscr.hline(divider_row, 0, curses.ACS_HLINE, max_x)

    # Column header
    header_row = divider_row + 1
    stdscr.addstr(header_row, 0, format_header(max_x), COLOR_PAIR_HEADER)
    stdscr.clrtoeol()

    # Update output entries
    first_entry_row = header_row + 1
    row_limit = max_y - 3  # rows from here down are not used for entries
    for idx, entry in enumerate(output_entries):
        row = first_entry_row + idx
        if row >= row_limit:
            break

        # Only draw if entry has changed or selection state changed
        if entry.changed or selection_moved:
            is_selected = current_line == len(input_fields) + idx
            draw_output_entry(stdscr, entry, row, is_selected, max_x)
            entry.mark_drawn()  # Mark as drawn after updating

    # Clear any extra lines if output entries were removed
    stale_start = first_entry_row + len(output_entries)
    stale_stop = min(first_entry_row + len(last_drawn_state["output_entries"]), row_limit)
    for row in range(stale_start, stale_stop):
        stdscr.move(row, 0)
        stdscr.clrtoeol()

    last_drawn_state["current_line"] = current_line
    last_drawn_state["output_entries"] = [{"log_id": entry.log_id} for entry in output_entries]


def draw_input_field(stdscr, field, is_selected, max_x):
    """Draw one ``label: value`` input field at the position stored in ``field``.

    A selected field is drawn entirely in the selection colour; otherwise the
    label and the value use their own colours. ``max_x`` is accepted for
    signature compatibility and is not used.
    """
    col, row = field["x"], field["y"]
    prompt = field["label"] + ": "
    value = field["value"]

    if is_selected:
        label_attr = value_attr = COLOR_PAIR_SELECTED
    else:
        label_attr, value_attr = COLOR_PAIR_LABEL, COLOR_PAIR_VALUE

    stdscr.addstr(row, col, prompt, label_attr)
    stdscr.addstr(row, col + len(prompt), value, value_attr)
    # Wipe whatever a longer previous value left behind on this row.
    stdscr.clrtoeol()


# Widths of the fixed-size table columns, in display order:
# Command, Model, Device, Status, Speed, TTFT, PCC.
# The trailing "Output" column receives whatever horizontal space is left.
_FIXED_COLUMN_WIDTHS = (20, 20, 10, 20, 10, 10, 10)
# Number of columns left unused at the right edge of every table row.
_TABLE_RIGHT_MARGIN = 5

# Positions of the columns that get special colouring.
_STATUS_COLUMN = 3
_SPEED_COLUMN = 4
_TTFT_COLUMN = 5
_PCC_COLUMN = 6


def _output_column_widths(max_x):
    """Return one width per table column (8 in total) for a screen ``max_x`` wide.

    The widths always sum to ``max(max_x - 5, 0)``: fixed columns are shrunk
    (down to zero) when the terminal is too narrow, and the final "Output"
    column takes the remaining space.
    """
    budget = max(max_x - _TABLE_RIGHT_MARGIN, 0)
    widths = []
    for preferred in _FIXED_COLUMN_WIDTHS:
        width = min(preferred, budget)
        widths.append(width)
        budget -= width
    widths.append(budget)
    return widths


def _status_color(status):
    """Return the curses attribute used for the status column."""
    if status in ("Waiting", "Cancelled", "Terminating", "Resetting"):
        return COLOR_PAIR_WAITING
    if status in ("Running", "Initializing device", "Prefill", "Decode", "Starting") or status.startswith("Loading "):
        return COLOR_PAIR_RUNNING
    if status == "Finished":
        return COLOR_PAIR_FINISHED
    if status == "Error":
        return COLOR_PAIR_ERROR
    return curses.color_pair(0)


def _pcc_color(pcc_text):
    """Return the curses attribute for the PCC column.

    Numeric values are graded (> 0.99 green, > 0.98 yellow, otherwise red);
    non-numeric text (e.g. a top-1/top-5 accuracy pair) uses the accuracy colour.
    """
    if not pcc_text:
        return curses.color_pair(0)
    try:
        pcc_value = float(pcc_text)
    except ValueError:
        return COLOR_PAIR_PCC_ACCURACY
    if pcc_value > 0.99:
        return COLOR_PAIR_PCC_GREEN
    if pcc_value > 0.98:
        return COLOR_PAIR_PCC_YELLOW
    return COLOR_PAIR_PCC_RED


def draw_output_entry(stdscr, entry, y, is_selected, max_x):
    """Draw one output entry as a table row on screen row ``y``.

    Columns are: command, model (without a leading ``hf:``), device, status,
    speed, TTFT, PCC and output. Previously only seven widths were supplied
    for these eight columns, so the output text was never shown and the PCC
    column absorbed the remaining width; every column now has its own width.

    Args:
        stdscr: curses window to draw on.
        entry: entry exposing ``command_name``, ``model``, ``device``,
            ``status``, ``speed``, ``ttft``, ``pcc`` and ``output``.
        y: screen row to draw on.
        is_selected: draw the whole row in the selection colour when true.
        max_x: screen width used to size the columns.
    """
    model = entry.model
    cols = [
        entry.command_name,
        model.split("hf:", 1)[1] if model.startswith("hf:") else model,
        entry.device,
        entry.status,
        entry.speed or "",
        entry.ttft or "",
        entry.pcc or "",
        entry.output or "",
    ]

    x = 0
    for i, (col, width) in enumerate(zip(cols, _output_column_widths(max_x))):
        col_text = str(col)[:width].ljust(width)
        if is_selected:
            color = COLOR_PAIR_SELECTED
        elif i == _STATUS_COLUMN:
            color = _status_color(entry.status)
        elif i in (_SPEED_COLUMN, _TTFT_COLUMN):
            color = COLOR_PAIR_SPEED
        elif i == _PCC_COLUMN:
            color = _pcc_color(col)
        else:
            color = curses.color_pair(0)
        stdscr.addstr(y, x, col_text, color)
        x += width
    stdscr.clrtoeol()  # Clear the rest of the line


def format_header(max_x):
    """Return the table header line, padded to the same column widths as the rows."""
    titles = ["Command", "Model", "Device", "Status", "Speed", "TTFT(ms)", "PCC", "Output"]
    return "".join(title[:width].ljust(width) for title, width in zip(titles, _output_column_widths(max_x)))


def parse_list(input_str, allow_space=True):
    """Split a user-entered list into its non-empty, stripped items.

    Items are separated by commas; when ``allow_space`` is true, whitespace
    separates items as well. Blank input yields ``[""]`` (a single empty item).
    """
    text = input_str.strip()
    if not text:
        return [""]

    pieces = re.split(r"[,\s]+", text) if allow_space else text.split(",")
    return [piece.strip() for piece in pieces if piece.strip()]


def worker_thread_func(output_entries, stop_event, screen_lock, screen_needs_update):
    """Scheduler loop: start the next waiting entry whenever nothing is busy.

    Runs until ``stop_event`` is set. On every pass it looks for an entry that
    is busy (its process is still running, or its status is one of the
    transitional states below); if there is none, the first entry whose status
    is "Waiting" is started. ``screen_needs_update`` is set on every pass and
    the loop then sleeps for 0.1 seconds.
    """
    transitional_states = ["Resetting", "Terminating", "Initializing device"]

    def is_busy(entry):
        with entry["lock"]:
            process = entry["process"]
            # poll() returns None while the process is still running.
            return (process and process.poll() is None) or entry["status"] in transitional_states

    while not stop_event.is_set():
        running_entry = next((entry for entry in output_entries if is_busy(entry)), None)

        if not running_entry:
            for candidate in output_entries:
                with candidate["lock"]:
                    if candidate["status"] == "Waiting":
                        run_entry_command(candidate, screen_lock, output_entries, screen_needs_update)
                        break

        # Ask the main loop to repaint on every pass.
        screen_needs_update.set()
        time.sleep(0.1)


def _fail_entry(entry, message, screen_needs_update):
    """Mark ``entry`` as failed, record ``message`` in its log and close the log file."""
    entry["status"] = "Error"
    entry["output"] = message
    screen_needs_update.set()

    log_file = entry["log_file"]
    log_file.write(message + "\n")
    log_file.flush()
    # No reader thread is started on a failure path, so close the file here.
    log_file.close()


def run_entry_command(entry, screen_lock, output_entries, screen_needs_update):
    """Resolve an entry's command and launch it as a subprocess.

    The entry's ``command_input`` is either one of the shortcuts below or a
    raw command starting with ``"pytest "``; anything else marks the entry as
    "Error". On success the subprocess is started in its own session (so the
    whole process group can be signalled later) and a daemon thread running
    ``process_output`` is started to consume its output.

    Args:
        entry: entry to run; its ``status``, ``output``, ``process`` and
            ``thread`` items are updated.
        screen_lock: passed through to ``process_output``.
        output_entries: passed through to ``process_output``.
        screen_needs_update: event set whenever the entry's state changes.
    """
    entry["status"] = "Initializing device"
    screen_needs_update.set()

    # Set environment variables
    env = os.environ.copy()
    env["MESH_DEVICE"] = entry["device"]
    model_dir = get_dir(entry["model"])
    dir_env = "HF_MODEL"
    if model_dir.startswith("hf:"):
        model_dir = model_dir.split("hf:", 1)[1]
    env[dir_env] = model_dir

    # Open log file
    entry.open_log_file()

    actual_device = get_device()
    env["ACTUAL_DEVICE"] = actual_device

    # Define command shortcuts
    command_shortcuts = {
        "accuracy": "pytest models/tt_transformers/demo/simple_text_demo.py -k 'performance and ci-token-accuracy'",
        "accuracy-acc": "pytest models/tt_transformers/demo/simple_text_demo.py -k 'accuracy and ci-token-accuracy'",
        "demo": "pytest models/tt_transformers/demo/simple_text_demo.py -k performance-batch-1",
        "demo-acc": "pytest models/tt_transformers/demo/simple_text_demo.py -k accuracy-batch-1",
        "demo-32": "pytest models/tt_transformers/demo/simple_text_demo.py -k performance-batch-32",
        "demo-long": "pytest models/tt_transformers/demo/simple_text_demo.py -k performance-long",
        "demo-ci-1": "pytest models/tt_transformers/demo/simple_text_demo.py -k performance-ci-1",
        "demo-ci-32": "pytest models/tt_transformers/demo/simple_text_demo.py -k performance-ci-32",
        "demo-dp-4": "pytest models/tt_transformers/demo/simple_text_demo.py -k 'performance and DP-4-b1'",
        "demo-dp-8": "pytest models/tt_transformers/demo/simple_text_demo.py -k 'performance and DP-8-b1'",
        "demo-32-dp-4": "pytest models/tt_transformers/demo/simple_text_demo.py -k 'performance and DP-4-b32'",
        "demo-ci-dp-4": "pytest models/tt_transformers/demo/simple_text_demo.py -k 'performance and ci-b1-DP-4'",
        "demo-ci-dp-8": "pytest models/tt_transformers/demo/simple_text_demo.py -k 'performance and ci-b1-DP-8'",
        "attention": "pytest models/tt_transformers/tests/test_attention.py",
        "attention-prefill": "pytest models/tt_transformers/tests/test_attention_prefill.py",
        "mlp": "pytest models/tt_transformers/tests/test_mlp.py",
        "rmsnorm": "pytest models/tt_transformers/tests/test_rms_norm.py",
        "embedding": "pytest models/tt_transformers/tests/test_embedding.py",
        "decoder": "pytest models/tt_transformers/tests/test_decoder.py",
        "decoder-prefill": "pytest models/tt_transformers/tests/test_decoder_prefill.py",
        "lm-head": "pytest models/tt_transformers/tests/test_lm_head.py",
        "model": "pytest models/tt_transformers/tests/test_model.py -k 'performance-256 and full'",
        "model-quick": "pytest models/tt_transformers/tests/test_model.py -k 'performance-256 and quick'",
        "model-prefill": "pytest models/tt_transformers/tests/test_model_prefill.py -k 'all_layers and performance- and 4k'",
        # Precision and Math Fidelity tests for Pareto analysis
        **pareto_commands,
        **pareto_commands_from_json,
        # Vision tests (require 11B weights)
        "vision-mlp": "pytest models/tt_transformers/tests/multimodal/test_image_mlp.py",
        "vision-attn": "pytest models/tt_transformers/tests/multimodal/test_image_attention.py",
        "vision-block": "pytest models/tt_transformers/tests/multimodal/test_image_block.py",
        "vision-xfmr": "pytest models/tt_transformers/tests/multimodal/test_image_transformer.py",
        "vision-xattn": "pytest models/tt_transformers/tests/multimodal/test_cross_attention.py",
        "vision-xblock": "pytest models/tt_transformers/tests/multimodal/test_cross_block.py",
        "vision-conv": "pytest models/tt_transformers/tests/multimodal/test_conv2d_patch.py",
        "vision-class": "pytest models/tt_transformers/tests/multimodal/test_class_embedding.py",
        "vision-tile-pos": "pytest models/tt_transformers/tests/multimodal/test_tile_position_embedding.py",
        "vision-pos": "pytest models/tt_transformers/tests/multimodal/test_positional_embedding.py",
        "vision-encoder": "pytest models/tt_transformers/tests/multimodal/test_vision_encoder.py",
        "vision-text-xfmr": "pytest models/tt_transformers/tests/multimodal/test_cross_attention_transformer_text.py",
        "vision-vision-xfmr": "pytest models/tt_transformers/tests/multimodal/test_cross_attention_transformer_vision.py",
    }

    # Check if the command is a shortcut and replace it if necessary
    command_input = entry["command_input"]
    if command_input in command_shortcuts:
        command_input = command_shortcuts[command_input]
    elif not command_input.startswith("pytest "):
        # Invalid command: report the full list of valid shortcuts and stop
        # before trying to parse or run it. (Raw "pytest ..." commands run as-is.)
        _fail_entry(
            entry,
            f"Warning: '{command_input}' is not a valid command. Valid commands are: "
            + ", ".join(command_shortcuts.keys()),
            screen_needs_update,
        )
        return

    # Prepare the command; malformed quoting must not take down the worker thread.
    try:
        cmd_list = shlex.split(command_input)
    except ValueError as exc:
        _fail_entry(entry, f"Warning: could not parse command '{command_input}': {exc}", screen_needs_update)
        return

    # Write the command to the log file
    entry["log_file"].write(
        f"MESH_DEVICE={entry['device']} {dir_env}={model_dir} ACTUAL_DEVICE={actual_device} {command_input}" + "\n"
    )

    # Start the subprocess in its own session/process group. start_new_session
    # replaces preexec_fn=os.setsid, which is not safe to use from a threaded
    # program such as this one.
    try:
        entry["process"] = subprocess.Popen(
            cmd_list,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            env=env,
            text=True,
            start_new_session=True,
        )
    except OSError as exc:
        # E.g. the executable is missing: surface it instead of leaving the
        # entry stuck in "Initializing device".
        _fail_entry(entry, f"Failed to start command: {exc}", screen_needs_update)
        return

    # Read the output in a separate thread
    entry["thread"] = threading.Thread(
        target=process_output, args=(entry, screen_lock, output_entries, screen_needs_update)
    )
    entry["thread"].daemon = True
    entry["thread"].start()


def process_output(entry, screen_lock, output_entries, screen_needs_update):
    """Consume a running entry's output, mirroring it to the log and updating the entry.

    Reads the subprocess's stdout line by line until EOF. Every line is written
    to the entry's log file and passed to ``parse_output_line``; any resulting
    change to status, output, speed, TTFT or PCC is stored on the entry under
    ``entry.lock``. When the stream ends (or an exception escapes the loop) the
    process is reaped and the final status is set from its return code.

    Args:
        entry: entry whose ``process`` and ``log_file`` are already set.
        screen_lock: lock that is briefly acquired after each line.
        output_entries: container providing ``save_state()``.
        screen_needs_update: event set whenever the entry changes.
    """
    process = entry.process
    log_file = entry.log_file
    previous_line = ""
    try:
        # readline() returns "" at EOF, which ends the iteration.
        for line in iter(process.stdout.readline, ""):
            # Write to log file
            log_file.write(line)
            log_file.flush()

            # Update status and output based on output
            status, output, speed, ttft, pcc = parse_output_line(line, previous_line, entry.status)
            previous_line = line.strip()

            with entry.lock:
                if status != entry.status or output or speed is not None or ttft is not None or pcc is not None:
                    entry.status = status  # This will mark entry as changed via __setattr__
                    if output:
                        entry.output = output
                    if speed is not None:
                        entry.speed = f"{speed:.1f}"
                    if ttft is not None:
                        entry.ttft = f"{ttft:.0f}"
                    if pcc is not None:
                        # Numeric PCC values: keep the lowest one seen so far.
                        # Non-numeric values (float() raises) simply replace
                        # whatever was stored.
                        try:
                            pcc_value = float(pcc)
                            if entry.pcc is None or pcc_value < float(entry.pcc):
                                entry.pcc = pcc
                        except ValueError:
                            entry.pcc = pcc
                    # Save state whenever process status changes
                    output_entries.save_state()
                    screen_needs_update.set()

            # Acquire and immediately release the screen lock; no drawing
            # happens here.
            with screen_lock:
                pass  # Screen will be updated in main loop

    finally:
        # Ensure we close the stdout stream
        process.stdout.close()

        # Wait for the process to fully terminate
        process.wait()

        with entry.lock:
            if process.returncode != 0:
                if entry.stop_event.is_set():
                    # Non-zero exit after a cancel request.
                    entry.status = "Cancelled"
                else:
                    # Non-zero exit without a cancel request: report the error
                    # found in the log (if any) and start a device reset.
                    exception_name = find_exception_in_log(entry.log_file.name)
                    entry.status = "Error"
                    if exception_name:
                        entry.output = exception_name
                    reset_device_async(entry, screen_lock, screen_needs_update)
            else:
                entry.status = "Finished"
            # Save state when process completes
            output_entries.save_state()
        entry.process = None
        log_file.close()

        screen_needs_update.set()


def parse_output_line(line, previous_line, current_status):
    """Interpret one line of test output.

    Returns a ``(status, output, speed, ttft, pcc)`` tuple where ``status`` is
    the (possibly unchanged) entry status, ``output`` is text to display or
    ``None``, and ``speed``/``ttft``/``pcc`` are metrics found on this line or
    ``None``. ``speed`` and ``ttft`` are floats; ``pcc`` is a formatted string
    (either a 5-decimal number or a ``top1|top5`` accuracy pair).
    """
    line = line.strip()

    # --- Metrics -----------------------------------------------------------
    speed = None
    tokens_per_second = re.search(r"@ (\d+\.\d+) tok/s/user", line)
    if tokens_per_second is not None:
        speed = float(tokens_per_second.group(1))
    else:
        # Perf tests report an end-to-end latency in seconds instead.
        end_to_end = re.search(r"end_to_end_inference: (\d+\.\d+)s", line)
        if end_to_end is not None:
            speed = 1000 * float(end_to_end.group(1))  # convert to ms

    first_token = re.search(r"\(TTFT\)\: (\d+\.\d+)ms", line)
    ttft = float(first_token.group(1)) if first_token is not None else None

    pcc = None
    pcc_found = re.search(r"PCC: (\d+\.\d+)", line)
    if pcc_found is not None:
        pcc = f"{float(pcc_found.group(1)):.5f}"
    else:
        accuracy = re.search(r"Top-1: (\d+)% \| Top-5: (\d+)%", line)
        if accuracy is not None:
            top1, top5 = accuracy.groups()
            pcc = f"{top1.strip():<3s}|{top5.strip():>3s}"

    def result(status, output=None):
        return status, output, speed, ttft, pcc

    # --- Status markers, checked in priority order ---------------------------
    for marker, status in (
        ("Initializing device", "Initializing device"),
        ("Loading weights", "Loading weights"),
    ):
        if marker in line:
            return result(status)

    layer = re.search(r"layers\.(\d+)\.", line)
    if layer is not None:
        return result(f"Loading layer {layer.group(1)}")

    for marker, status in (
        ("Starting inference...", "Starting"),
        ("Starting prefill...", "Prefill"),
        ("Starting decode...", "Decode"),
        ("- OUTPUT", "Waiting for output"),
    ):
        if marker in line:
            return result(status)

    # --- Generated text following an OUTPUT marker ---------------------------
    assistant_tag = "<|start_header_id|>assistant<|end_header_id|>"
    if current_status == "Waiting for output" and "- OUTPUT" in previous_line:
        if assistant_tag not in line:
            return result("Running", line)
        reply = line.split(assistant_tag, 1)[1].strip()
        if reply:
            return result("Running", reply)
        return result("Assistant output")  # wait for a non-blank line
    if current_status == "Assistant output" and line:  # skip blank lines
        return result("Running", line)

    # --- Log lines emitted by the tests themselves ---------------------------
    test_line = re.search(r"\| models\.demos\.llama3\.tests\..+ - (.+)", line)
    if test_line is None:
        return result(current_status)

    metric_seen = pcc is not None or speed is not None or ttft is not None
    if current_status.startswith("Loading") and metric_seen:
        current_status = "Running"
    return result(current_status, test_line.group(1))


# Model shortcut -> (environment variable overriding the location,
#                    default Hugging Face repo id,
#                    description used in the help text).
# Both the lookup and the error help below are generated from this table so
# they cannot drift apart.
_MODEL_SHORTCUTS = {
    "1b": ("LLAMA_32_1B_DIR", "meta-llama/Llama-3.2-1B-Instruct", "1b model"),
    "3b": ("LLAMA_32_3B_DIR", "meta-llama/Llama-3.2-3B-Instruct", "3b model"),
    "8b": ("LLAMA_31_8B_DIR", "meta-llama/Llama-3.1-8B-Instruct", "8b model"),
    "11b": ("LLAMA_32_11B_DIR", "meta-llama/Llama-3.2-11B-Vision-Instruct", "11b model"),
    "11b-b": ("LLAMA_32_11B_BASE_DIR", "meta-llama/Llama-3.2-11B-Vision", "11b-b model"),
    "70b": ("LLAMA_31_70B_DIR", "meta-llama/Llama-3.1-70B-Instruct", "70b model"),
    "70b-r1": (
        "DEEPSEEK_R1_LLAMA_70B_DIR",
        "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
        "DeepSeek R1 Llama 70b distill model",
    ),
    "q2-7b": ("QWEN_7B_DIR", "Qwen/Qwen2.5-7B-Instruct", "7b Qwen2.5 model"),
    "q2-72b": ("QWEN_72B_DIR", "Qwen/Qwen2.5-72B-Instruct", "72b Qwen2.5 model"),
    "q2-coder-32b": ("QWEN_CODER_32B_DIR", "Qwen/Qwen2.5-Coder-32B", "32b Qwen2.5 Coder model"),
    "q3-32b": ("QWEN_32B_DIR", "Qwen/Qwen3-32B", "32b Qwen3 model"),
}


def get_dir(model):
    """Resolve a model name to something usable as the model location.

    Resolution order:
      1. Names starting with ``hf:`` are returned unchanged.
      2. ``model`` is looked up in the local Hugging Face cache.
      3. ``model`` is treated (case-insensitively) as a shortcut; the matching
         environment variable, or the default repo id, is used either as an
         existing local path or as a repo id in the local Hugging Face cache.

    If nothing matches, a help message listing every supported environment
    variable is printed and the interpreter exits with status 1.
    """
    if model.startswith("hf:"):
        return model

    try:
        # returns path to hf_model snapshot based on HF_HOME or HF_HUB_CACHE
        return snapshot_download(model, local_files_only=True)
    except (LocalEntryNotFoundError, HFValidationError):
        pass

    shortcut = _MODEL_SHORTCUTS.get(model.lower())
    model_dir = os.environ.get(shortcut[0], shortcut[1]) if shortcut else ""

    if model_dir and os.path.exists(model_dir):
        return model_dir

    try:
        # returns path to hf_model snapshot based on HF_HOME or HF_HUB_CACHE
        return snapshot_download(model_dir, local_files_only=True)
    except (LocalEntryNotFoundError, HFValidationError):
        pass

    print(f"Error: The directory for the {model} model does not exist: {model_dir}")
    print("You can set the following environment variables to specify the correct directory path:")
    for env_var, _default_repo, description in _MODEL_SHORTCUTS.values():
        print(f"  - {env_var} for {description}")
    sys.exit(1)


def get_command_name(command_input):
    """Derive a short display name from a command line.

    For pytest commands this is the test file's base name without extension
    (``"pytest"`` if no argument follows). For anything else it is the base
    name of the first token. An empty or whitespace-only command yields ``""``
    instead of raising ``IndexError``.
    """
    if "pytest" in command_input:
        match = re.search(r"pytest\s+([\S]+)", command_input)
        if match is None:
            return "pytest"
        return os.path.basename(match.group(1)).split(".")[0]

    tokens = shlex.split(command_input)
    if not tokens:
        # Nothing to name; the caller decides how to treat an empty command.
        return ""
    return os.path.basename(tokens[0])


def find_exception_in_log(log_filename):
    """Return a short description of the most recent error in a log file.

    The log is scanned from the last line backwards. The first line matching
    any of the patterns below ends the search:

    * ``SomethingError:`` -> the exception class name
    * ``TT_FATAL <text>`` -> the text after ``TT_FATAL``
    * any other line containing ``FATAL`` -> the text after the first ``|``,
      or ``None`` when the line has no ``|``
    * a line starting with ``E`` followed by whitespace -> the rest of the line

    Returns ``None`` when nothing matches.
    """
    with open(log_filename, "r") as log:
        lines = log.readlines()

    for line in reversed(lines):
        # Python exceptions
        python_error = re.search(r"(\w+Error):", line)
        if python_error:
            return python_error.group(1)

        # TT_FATAL errors
        tt_fatal = re.search(r"TT_FATAL\s*(.+)", line)
        if tt_fatal:
            return tt_fatal.group(1).strip()

        # Other FATAL errors: the search stops here even without a "|" part.
        if "FATAL" in line:
            _, separator, detail = line.partition("|")
            return detail.strip() if separator else None

        # Lines starting with "E   "
        e_line = re.match(r"E\s+(.+)", line)
        if e_line:
            return e_line.group(1).strip()

    return None


def terminate_process_tree(pid):
    """Send SIGTERM to the process groups of ``pid`` and all of its descendants.

    Descendants are signalled first, then ``pid`` itself. Each process is
    handled independently: a process that has already exited is skipped
    without preventing the remaining ones (in particular the parent's group)
    from being signalled. Every process group is signalled at most once.
    """
    try:
        descendants = psutil.Process(pid).children(recursive=True)
    except psutil.NoSuchProcess:
        descendants = []  # Parent already gone; still try its group below.

    signalled_groups = set()
    for target_pid in [child.pid for child in descendants] + [pid]:
        try:
            pgid = os.getpgid(target_pid)
            if pgid in signalled_groups:
                continue
            signalled_groups.add(pgid)
            os.killpg(pgid, signal.SIGTERM)
        except ProcessLookupError:
            continue  # Process (group) already terminated


def generate_tg_config_file():
    """Write the reset configuration JSON for the current host.

    The file is written to ``~/.config/tenstorrent/reset_config_<HOSTNAME>.json``
    (the directory is created if needed). The host name is ``$HOSTNAME`` up to
    any ``-special-`` suffix; the motherboard name is derived from it by
    inserting ``mgmt`` before the third dash-separated part
    (``aus-glx-NR`` -> ``aus-glx-mgmt-NR``). Host names with fewer than three
    parts (including the ``"unknown"`` default) cannot be mapped that way, so
    the host name itself is used as the motherboard name instead of crashing.
    """
    hostname = os.environ.get("HOSTNAME", "unknown")
    host_name_config = hostname.split("-special-")[0]  # e.g. aus-glx-NR

    name_parts = host_name_config.split("-")
    if len(name_parts) >= 3:
        mobo_name_config = f"{name_parts[0]}-{name_parts[1]}-mgmt-{name_parts[2]}"  # e.g. aus-glx-mgmt-NR
    else:
        mobo_name_config = host_name_config

    config_file = os.path.expanduser(f"~/.config/tenstorrent/reset_config_{hostname}.json")
    config = {
        "time": "",
        "host_name": host_name_config,
        "wh_link_reset": {"pci_index": [0, 1, 2, 3]},
        "re_init_devices": True,
        "wh_mobo_reset": [
            {
                "nb_host_pci_idx": [0, 1, 2, 3],
                "mobo": mobo_name_config,
                "disabled_ports": ["0:0", "0:1", "0:2", "1:0", "1:1", "1:2", "6:2", "7:2"],
            }
        ],
    }

    # The config directory does not exist on a fresh machine.
    os.makedirs(os.path.dirname(config_file), exist_ok=True)
    with open(config_file, "w") as f:
        json.dump(config, f)


def reset_device_async(entry, screen_lock, screen_needs_update):
    """Reset the device in a background thread while showing "Resetting".

    The entry's status is switched to "Resetting" immediately. A daemon thread
    then runs ``tt-smi -r`` with this host's reset config file and, whether or
    not the reset succeeded, restores the status the entry had on entry to
    this function and sets ``screen_needs_update``.

    The reset is best-effort: a failing ``tt-smi`` and a ``tt-smi`` that
    cannot be started at all (e.g. not installed) are both ignored.
    ``screen_lock`` is accepted for signature compatibility and is not used.
    """
    previous_status = entry.status

    def run_reset():
        hostname = os.environ.get("HOSTNAME", "unknown")
        config_file = os.path.expanduser(f"~/.config/tenstorrent/reset_config_{hostname}.json")
        try:
            subprocess.run(
                ["tt-smi", "-r", config_file],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
        except (subprocess.CalledProcessError, OSError):
            # Best-effort: an uncaught exception here would only dump a
            # traceback from the daemon thread.
            pass
        finally:
            with entry.lock:
                entry.status = previous_status
            screen_needs_update.set()

    entry.status = "Resetting"
    worker = threading.Thread(target=run_reset)
    worker.daemon = True
    worker.start()


def draw_help_bar(stdscr, current_line, num_input_fields, num_output_entries):
    """Render the help text for the current selection on the bottom screen row."""
    height, width = stdscr.getmaxyx()
    text = get_help_text(current_line, num_input_fields, num_output_entries)
    # Truncate to one column less than the screen width.
    visible_text = text[: width - 1]
    stdscr.addstr(height - 1, 0, visible_text, COLOR_PAIR_HELP_BAR)
    stdscr.clrtoeol()


def get_help_text(current_line, num_input_fields, num_output_entries):
    """Return the help-bar text for the currently selected line.

    Lines 0 and 1 have dedicated hints, the remaining input fields share a
    generic one, and every line past the input fields gets the output-entry
    key bindings. ``num_output_entries`` is accepted but not used.
    """
    dedicated_hints = {
        0: "Shortcuts: demo, tests, accuracy, 'help' for full list | Enter: Submit | ↑↓: Navigate fields | Esc: Exit",
        1: 'New: prefix with "hf:" to run a HF model e.g. "hf:Qwen/Qwen3-4B" | Enter: Submit | ↑↓: Navigate fields | Esc: Exit',
    }
    if current_line in dedicated_hints:
        return dedicated_hints[current_line]

    if current_line <= num_input_fields - 1:
        return "Enter: Next field | ↑↓: Navigate fields | Esc: Exit"

    return "Enter: View log | Backspace/x: Cancel entry | X: Cancel all | r: Restart entry | p: Reparse log | m: Export markdown | ↑↓: Navigate entries | Esc: Exit"


def cancel_entry(entry):
    """Cancel an entry and report whether it may be removed right away.

    A still-running process is asked to stop (its stop event is set, its
    process tree is terminated and its status becomes "Terminating"); such an
    entry must not be removed yet, so ``False`` is returned. Otherwise the
    entry may be removed unless it is currently "Resetting".
    """
    with entry["lock"]:
        process = entry["process"]
        still_running = process and process.poll() is None
        if still_running:
            entry["stop_event"].set()
            terminate_process_tree(process.pid)
            entry["status"] = "Terminating"
            return False
        return entry["status"] != "Resetting"


# Header shared by every results table in the exported markdown.
_RESULTS_TABLE_HEADER = (
    "| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) | TTFT (ms) |",
    "|-------|--------|-----------|-----------|---------------|-----------|",
)

# Display names for the known model shortcuts.
_MODEL_DISPLAY_NAMES = {
    "1b": "Llama3.2-1B",
    "3b": "Llama3.2-3B",
    "8b": "Llama3.1-8B",
    "11b": "Llama3.2-11B",
    "70b": "Llama3.1-70B",
    "70b-r1": "DeepSeek-R1-Llama-70B",
    "q2-7b": "Qwen2.5-7B",
    "q2-72b": "Qwen2.5-72B",
    "q2-coder-32b": "Qwen2.5-Coder-32B",
    "q3-32b": "Qwen3-32B",
}


def _display_model_name(model):
    """Return the table name for ``model``; unknown models fall back to their own name (minus ``hf:``)."""
    known = _MODEL_DISPLAY_NAMES.get(model.lower())
    if known:
        return known
    return model.split("hf:", 1)[1] if model.startswith("hf:") else model


def _parse_top1_top5(pcc):
    """Return ``(top1, top5)`` parsed from a ``"NN | NN"`` accuracy string, else ``("N/A", "N/A")``."""
    if pcc:
        match = re.match(r"(\d+)\s*\|\s*(\d+)", pcc)
        if match:
            return match.group(1), match.group(2)
    return "N/A", "N/A"


def _merge_result_row(rows, key, top1, top5, speed, ttft):
    """Merge one result into the ``[key, top1, top5, speed, ttft]`` row for ``key``, creating it if needed."""
    row = next((candidate for candidate in rows if candidate[0] == key), None)
    if row is None:
        rows.append([key, top1, top5, speed or "N/A", ttft or "N/A"])
        return
    if top1 != "N/A":
        row[1:3] = [top1, top5]
    if speed:
        row[3] = speed
    if ttft:
        row[4] = ttft


def _table_row(key, top1, top5, speed, ttft):
    """Format one markdown table row for a ``(model, device)`` key."""
    model, device = key
    return f"| {_display_model_name(model)} | {device} | {top1} | {top5} | {speed} | {ttft} |"


def export_results_to_markdown(output_entries, stdscr):
    """Export finished results as markdown tables and Pareto plots.

    Entries with status "Finished" or "Error" are grouped by command name:

    * ``demo`` / ``accuracy`` -> "Performance" table
    * ``demo-acc`` / ``accuracy-acc`` -> "Accuracy" table
    * ``demo-dmf*`` / ``accuracy-dmf*`` -> one table per setting index, plus a
      speed-vs-top-1 plot per model/device pair

    Unless any entry came from a JSON config, the tables are written to
    ``PERF.md`` and shown with ``less``. The function then waits for a key
    press before returning.

    Missing values are rendered as "N/A"; results lacking a numeric speed or
    top-1 value are listed in the tables but left out of the plots. Models
    without a known display name are shown under their own name.

    Args:
        output_entries: iterable of entries exposing ``status``, ``model``,
            ``device``, ``command_name``, ``speed``, ``ttft`` and ``pcc``.
        stdscr: curses window used for the final messages and key press.
    """
    # Ordered lists keep the tables in the order entries were submitted.
    perf_entries = []
    acc_entries = []
    dmf_entries = {}  # setting index -> {(model, device): [top1, top5, speed, ttft]}

    def dmf_row(setting_idx, key):
        return dmf_entries.setdefault(setting_idx, {}).setdefault(key, ["N/A", "N/A", "N/A", "N/A"])

    # Collect results from entries in their original order
    for entry in output_entries:
        if entry.status not in ("Finished", "Error"):
            continue
        key = (entry.model, entry.device)
        name = entry.command_name

        if name in ("demo", "accuracy", "demo-acc", "accuracy-acc"):
            rows = acc_entries if name.endswith("-acc") else perf_entries
            if name.startswith("demo"):
                # Demo runs provide speed and TTFT only.
                _merge_result_row(rows, key, "N/A", "N/A", entry.speed, entry.ttft)
            else:
                # Accuracy runs provide top-1/top-5 only.
                top1, top5 = _parse_top1_top5(entry.pcc)
                _merge_result_row(rows, key, top1, top5, None, None)

        elif name.startswith("demo-dmf"):
            match = re.search(r"demo-dmf(?:-json)?-(\d+)", name)
            if not match:
                continue
            row = dmf_row(int(match.group(1)), key)
            row[2] = entry.speed or "N/A"
            row[3] = entry.ttft or "N/A"

        elif name.startswith("accuracy-dmf"):
            match = re.search(r"accuracy-dmf(?:-json)?-(\d+)", name)
            if not match:
                continue
            row = dmf_row(int(match.group(1)), key)
            row[0:2] = _parse_top1_top5(entry.pcc)

    # Create markdown content
    markdown_lines = []
    if perf_entries:
        markdown_lines.extend(
            [
                "## Performance",
                "",
                "This configuration uses bfp4 MLP FF1+FF3 for all models.",
                "",
                *_RESULTS_TABLE_HEADER,
            ]
        )
        markdown_lines.extend(_table_row(*row) for row in perf_entries)

    # Add accuracy table
    if acc_entries:
        markdown_lines.extend(
            [
                "",
                "## Accuracy",
                "",
                "This configuration uses bfp4 MLP FF1+FF3 only for the 3.1-70B model.",
                "",
                *_RESULTS_TABLE_HEADER,
            ]
        )
        markdown_lines.extend(_table_row(*row) for row in acc_entries)

    # add precision tables
    def gen_description(setting_idx: int) -> str:
        return parse_optimizations(dtype_mf_settings[setting_idx]).__name__

    if dmf_entries:
        data_to_plot = {}  # "<model>_<device>" -> (speeds, top1 values, descriptions)
        for setting_idx, entries in dmf_entries.items():
            # Settings that came from JSON have no built-in description. The
            # lookahead keeps "json-1" from matching "json-10".
            json_marker = re.compile(rf"json-{setting_idx}(?!\d)")
            used_json_for_setting = any(json_marker.search(e.command_name) for e in output_entries)
            desc = "" if used_json_for_setting else gen_description(setting_idx)
            markdown_lines.extend(["", f"## {desc}", "", *_RESULTS_TABLE_HEADER])

            for key, (top1, top5, speed, ttft) in entries.items():
                markdown_lines.append(_table_row(key, top1, top5, speed, ttft))

                try:
                    point = (float(speed), float(top1))
                except (TypeError, ValueError):
                    continue  # incomplete result ("N/A"): nothing to plot

                model, device = key
                # "/" would be read as a directory separator if the name is used in a file name.
                model_device_name = (_display_model_name(model) + "_" + device).replace("/", "-")
                speeds, accuracies, labels = data_to_plot.setdefault(model_device_name, ([], [], []))
                speeds.append(point[0])
                accuracies.append(point[1])
                labels.append(desc)

        # plot the data between speed and top1 for each model-device pair
        for model_device_name, (speeds, accuracies, labels) in data_to_plot.items():
            plot_data_with_desc(speeds, accuracies, "Speed (t/s/u)", "Top-1 (%)", labels, model_device_name)

    # Check if any entry came from a JSON config
    used_json = any("json" in entry.command_name for entry in output_entries)

    # Only update PERF.md and display it if not using JSON
    if not used_json:
        with open("PERF.md", "w") as f:
            f.write("\n".join(markdown_lines) + "\n")

        # Clear screen and show message
        stdscr.clear()

        # display 'PERF.md' using less
        # Save current terminal state
        curses.def_prog_mode()
        # Exit curses temporarily
        curses.endwin()
        # Run less command
        os.system("less -R PERF.md")
        # Resume curses
        curses.reset_prog_mode()
        stdscr.refresh()

        # display helpful message at the top of the screen
        stdscr.addstr(0, 0, f"Table written to {os.path.abspath('PERF.md')}")
    else:
        stdscr.clear()

    stdscr.addstr(1, 0, f"Plots written to .svg files in {os.path.abspath('data_plots/')}")
    stdscr.addstr(2, 0, "Press any key to return...")
    stdscr.refresh()

    # Temporarily make getch() blocking
    stdscr.nodelay(False)

    # Wait for a key press and flush input buffer
    stdscr.getch()
    curses.flushinp()

    # Restore non-blocking mode
    stdscr.nodelay(True)


def reparse_log_file(entry, screen_needs_update):
    """Re-read an entry's log file and refresh its speed, ttft, pcc and output.

    The entry's status is not modified; it is only handed to the line parser
    as context. After the pass, speed and ttft hold the last values found in
    the log, output holds the last non-empty output, and pcc holds the
    numerically smallest value seen (a value that cannot be compared as a
    float simply replaces the stored one).

    If the log file does not exist, the entry is left untouched and no
    update is signalled.

    Args:
        entry: Object providing get_log_filename() plus the status, speed,
            ttft, pcc and output attributes.
        screen_needs_update: Object whose set() method is called once the
            whole file has been processed (presumably a threading.Event).
    """
    try:
        with open(entry.get_log_filename(), "r") as log_file:
            # The status seen by the parser stays fixed for the whole pass.
            frozen_status = entry.status

            # Start from a clean slate so stale values do not survive.
            entry.speed = entry.ttft = entry.pcc = None

            last_line = ""
            for raw_line in log_file:
                _, parsed_output, parsed_speed, parsed_ttft, parsed_pcc = parse_output_line(
                    raw_line, last_line, frozen_status
                )
                last_line = raw_line.strip()

                if parsed_speed is not None:
                    entry.speed = f"{parsed_speed:.1f}"

                if parsed_ttft is not None:
                    entry.ttft = f"{parsed_ttft:.0f}"

                if parsed_pcc is not None:
                    try:
                        candidate = float(parsed_pcc)
                        take_candidate = entry.pcc is None or candidate < float(entry.pcc)
                    except ValueError:
                        # Not comparable as numbers: keep the newest value.
                        take_candidate = True
                    if take_candidate:
                        entry.pcc = parsed_pcc

                if parsed_output:
                    entry.output = parsed_output

            screen_needs_update.set()
    except FileNotFoundError:
        # No log file yet: nothing to reparse.
        return


def get_pareto_front(objective1: np.ndarray, objective2: np.ndarray) -> np.ndarray:
    """Return the indices of the Pareto-optimal points (larger is better for both objectives).

    A point is kept when no other point is at least as good in both objectives
    and strictly better in at least one. Among exact duplicates only one point
    (the one with the lowest index) is kept.

    Method for 2D data: visit the points by decreasing objective1 (ties broken
    by decreasing objective2) and keep each point whose objective2 exceeds the
    best objective2 seen so far.

    Args:
        objective1: 1-D array-like of values for the first objective.
        objective2: 1-D array-like of values for the second objective, same
            length as objective1.

    Returns:
        Integer index array of the Pareto-optimal points, ordered by
        decreasing objective1. Empty (but still integer-typed, so it can be
        used for indexing) when the inputs are empty.
    """
    objective1 = np.asarray(objective1, dtype=float)
    objective2 = np.asarray(objective2, dtype=float)

    # lexsort sorts by the LAST key first: primary key is objective1
    # (descending), secondary key is objective2 (descending). The explicit
    # tie-break keeps a point with equal objective1 but lower objective2 from
    # being visited first and wrongly accepted onto the front.
    order = np.lexsort((-objective2, -objective1))

    pareto_indices = []
    current_best = -np.inf  # below every finite value, so negative objectives work too
    for idx in order:
        if objective2[idx] > current_best:
            pareto_indices.append(idx)
            current_best = objective2[idx]

    # Explicit integer dtype: np.array([]) would default to float64, which
    # cannot be used as an index array.
    return np.array(pareto_indices, dtype=np.intp)


def plot_data_with_desc(objective1, objective2, objective1_label, objective2_label, point_desc, model_device_name):
    """Plot two objectives against each other and highlight their Pareto front.

    Two files are written into the ``data_plots`` directory (created if needed):

    * ``{model_device_name}_point_desc.csv``: one row per data point with its
      index, both objective values and its description.
    * ``{model_device_name}_pareto.svg``: scatter plot of all points (each
      labelled with its index), the Pareto front in red, and a footnote block
      below the axes describing every Pareto point.

    Args:
        objective1: Sequence of x values.
        objective2: Sequence of y values, same length as objective1.
        objective1_label: Axis label and CSV column name for objective1.
        objective2_label: Axis label and CSV column name for objective2.
        point_desc: Sequence of description strings indexed like the objectives.
            Descriptions of the form "precision_cfg = {...}, fidelity_cfg = {...}"
            are rendered as footnotes; others only appear in the CSV.
        model_device_name: "<model>_<device>" string used for the title and
            the output file names.
    """
    import csv  # local import: only this function needs it

    objective1 = np.array(objective1)
    objective2 = np.array(objective2)
    # Force an integer dtype so an empty front is still a valid index array.
    pareto_indices = np.asarray(get_pareto_front(objective1, objective2), dtype=int)
    pareto_obj1 = objective1[pareto_indices]
    pareto_obj2 = objective2[pareto_indices]

    # Write the per-point descriptions next to the plot. The csv module is used
    # because the descriptions contain commas and must be quoted.
    os.makedirs("data_plots", exist_ok=True)
    with open(f"data_plots/{model_device_name}_point_desc.csv", "w", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["index", objective1_label, objective2_label, "point_desc"])
        for i, (value1, value2) in enumerate(zip(objective1, objective2)):
            writer.writerow([i, f"{value1}", f"{value2}", point_desc[i]])

    # Use a dedicated figure and always close it, so a failure here cannot
    # leave a half-drawn figure behind for the next plot.
    fig = plt.figure()
    try:
        # Plot all data points; the Pareto points are highlighted further down.
        plt.scatter(objective1, objective2, label="Data Points", color="blue")
        # Label every point with its index just below the marker.
        for i, (value1, value2) in enumerate(zip(objective1, objective2)):
            plt.annotate(f"{i}", (value1, value2), textcoords="offset points", xytext=(0, -10), ha="center")
        # Plot the Pareto front as a connected red line.
        plt.plot(pareto_obj1, pareto_obj2, "r-", label="Pareto Front")
        plt.scatter(pareto_obj1, pareto_obj2, color="red")

        plt.xlabel(objective1_label)
        plt.ylabel(objective2_label)
        # Split at the LAST underscore so model names containing underscores
        # stay intact (assumes the device name itself has no underscore).
        model_name, separator, device_name = model_device_name.rpartition("_")
        if not separator:
            model_name, device_name = model_device_name, ""
        plt.title(
            f"Pareto Analysis between {objective1_label} and {objective2_label}\n of {model_name} on {device_name}"
        )
        plt.legend()

        # Grow the figure downwards by a fixed amount per Pareto point to make
        # room for the footnotes.
        fig_width, fig_height = fig.get_size_inches()
        height_incr = 0.7
        height_incr_total = height_incr * len(pareto_indices)
        new_fig_height = fig_height + height_incr_total

        # Keep the original top/bottom margins in inches and add the footnote
        # space below them. Fractions are relative to the NEW height, so
        # bottom stays below top however many Pareto points there are
        # (subplots_adjust raises ValueError otherwise).
        bottom_margin = fig.subplotpars.bottom * fig_height
        top_margin = (1 - fig.subplotpars.top) * fig_height
        fig.set_size_inches(fig_width, new_fig_height)
        plt.subplots_adjust(
            bottom=(height_incr_total + bottom_margin) / new_fig_height,
            top=1 - top_margin / new_fig_height,
        )

        # Footnote positions are given in inches from the bottom-left corner
        # of the figure (transform=fig.dpi_scale_trans) and are stacked upwards.
        x = 0.1
        y = height_incr / 4

        # Add a footnote for each Pareto point whose description has the
        # "precision_cfg = {...}, fidelity_cfg = {...}" form.
        for i in pareto_indices:
            desc_splits = point_desc[i].split("}")
            if len(desc_splits) < 2:
                continue
            dtype_desc = desc_splits[0] + "}"
            fidelity_desc = desc_splits[1][2:] + "}"  # 2: slice to remove the leading ", "

            # Break a long fidelity description before "li_qkv_prefill" so the
            # text fits in the plot; a description without it stays on one line.
            split_at = fidelity_desc.find("li_qkv_prefill")
            if split_at != -1:
                continuation = " " * len("fidelity_cfg = {") + fidelity_desc[split_at:]
                detail_lines = [continuation, fidelity_desc[:split_at], dtype_desc]
            else:
                detail_lines = [fidelity_desc, dtype_desc]

            # Lines are drawn bottom-up, so the last entry ends up on top.
            for text in detail_lines:
                plt.figtext(
                    x, y, text, ha="left", transform=fig.dpi_scale_trans, family="monospace", fontsize=6
                )
                y += height_incr / 8
            plt.figtext(
                x,
                y,
                f"{i}: {objective1_label}={objective1[i]}, {objective2_label}={objective2[i]}",
                ha="left",
                transform=fig.dpi_scale_trans,
                fontsize=10,
            )
            y += height_incr / 3

        # Centered pointer to the CSV file, placed above the footnotes.
        x = fig_width / 2
        plt.figtext(
            x,
            y,
            f"(use data-point id for details in {model_device_name}_point_desc.csv)",
            ha="center",
            transform=fig.dpi_scale_trans,
            fontsize=10,
        )
        plt.grid(True)
        plt.savefig(f"data_plots/{model_device_name}_pareto.svg")
    finally:
        plt.close(fig)


# Script entry point: runs only when the file is executed directly.
if __name__ == "__main__":
    # Force a 256-color xterm terminal type before curses is initialized.
    os.environ["TERM"] = "xterm-256color"

    # Check external tools before entering curses mode
    # (presumably `less`, which is used to display PERF.md, and tt-smi; see their definitions).
    ensure_less_installed()
    ensure_ttsmi_installed()
    # curses.wrapper initializes curses, calls main(stdscr) and restores the
    # terminal afterwards, even if main raises.
    curses.wrapper(main)
